{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc539d4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2025 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1f25b113",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "# Install (or upgrade) the Google Gen AI SDK, which provides the `google.genai`\n",
    "# package imported in the next cell. `%pip` is used rather than `!pip` so the\n",
    "# install targets the environment of the running kernel.\n",
    "%pip install -U -q 'google-genai>=1.0.0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "09a81a73",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from google import genai\n",
    "from IPython.display import Markdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf2eab8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the Gemini API key from the environment instead of hardcoding it, so the\n",
    "# secret is never saved with (or committed alongside) this notebook.\n",
    "# Set it before starting Jupyter, e.g. `export GOOGLE_API_KEY=...`.\n",
    "GOOGLE_API_KEY = os.environ.get(\"GOOGLE_API_KEY\", \"\")\n",
    "if not GOOGLE_API_KEY:\n",
    "  # Fail here with an actionable message rather than later inside the SDK.\n",
    "  raise ValueError(\"GOOGLE_API_KEY is not set; export it before running this notebook.\")\n",
    "\n",
    "# Shared client used by the file-upload and generation cells below.\n",
    "client = genai.Client(api_key=GOOGLE_API_KEY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f51eb3cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gemini model that generates the conversion code; passed as `model=` to\n",
    "# `client.models.generate_content` in the final cell.\n",
    "# NOTE(review): confirm this model ID is available to your API key.\n",
    "MODEL_ID = \"gemini-2.0-pro\"\n",
    "# Name of the model to generate conversion code for; interpolated into the prompt,\n",
    "# including as the prefix of the three requested function names.\n",
    "target_model = \"Gemma3\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7908d62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Uploaded file 'files/96zbvn0jl8v9' as: https://generativelanguage.googleapis.com/v1beta/files/96zbvn0jl8v9\n",
      "Uploaded file 'files/i81ci5tcyjwa' as: https://generativelanguage.googleapis.com/v1beta/files/i81ci5tcyjwa\n",
      "Uploaded file 'files/poo15cv7l54e' as: https://generativelanguage.googleapis.com/v1beta/files/poo15cv7l54e\n",
      "Uploaded file 'files/84qrwp42q92e' as: https://generativelanguage.googleapis.com/v1beta/files/84qrwp42q92e\n"
     ]
    }
   ],
   "source": [
    "# Upload the two local context files with `client.files.upload`. The returned\n",
    "# file handles are sent to the model together with the prompt in the final cell.\n",
    "param_file = client.files.upload(file=\"context/param_mapping.py\")\n",
    "shape_file = client.files.upload(file=\"context/hf_shape.py\")\n",
    "\n",
    "# Report the name and URI reported back for each upload.\n",
    "for uploaded_file in (param_file, shape_file):\n",
    "  print(f\"Uploaded file '{uploaded_file.name}' as: {uploaded_file.uri}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a8b3dcf0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "```python\n",
       "\"\"\"\n",
       " Copyright 2025 Google LLC\n",
       "\n",
       " Licensed under the Apache License, Version 2.0 (the \"License\");\n",
       " you may not use this file except in compliance with the License.\n",
       " You may obtain a copy of the License at\n",
       "\n",
       "      https://www.apache.org/licenses/LICENSE-2.0\n",
       "\n",
       " Unless required by applicable law or agreed to in writing, software\n",
       " distributed under the License is distributed on an \"AS IS\" BASIS,\n",
       " WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
       " See the License for the specific language governing permissions and\n",
       " limitations under the License.\n",
       " \"\"\"\n",
       "\n",
       "import numpy as np\n",
       "import jax\n",
       "import jax.numpy as jnp\n",
       "\n",
       "\n",
       "def Gemma3_MAXTEXT_TO_HF_PARAM_MAPPING(config, scan_layers=False):\n",
       "  \"\"\"Returns mapping between MaxText and HuggingFace Gemma3 weight paths.\n",
       "\n",
       "  Args:\n",
       "      config (dict): Model configuration dictionary containing at least 'num_hidden_layers'.\n",
       "      scan_layers (bool, optional): Whether the MaxText model uses layer scanning optimization.\n",
       "          When True, decoder layers are stacked into a single tensor [dim1, #layers, dim2].\n",
       "          Defaults to False.\n",
       "\n",
       "  Returns:\n",
       "      dict: A mapping where:\n",
       "          - Keys are MaxText parameter paths\n",
       "          - Values are either:\n",
       "              - Single strings (HF parameter path) for unscanned parameters\n",
       "              - Lists of strings (HF parameter paths) for stacked layers when scan_layers=True\n",
       "  \"\"\"\n",
       "\n",
       "  nlayers = config[\"num_hidden_layers\"]\n",
       "  mapping = {\n",
       "      \"params-token_embedder-embedding\": \"model.embed_tokens.weight\",\n",
       "      \"params-decoder-decoder_norm-scale\": \"model.norm.weight\",\n",
       "  }\n",
       "  if scan_layers:\n",
       "    mapping = {\n",
       "        **mapping,\n",
       "        \"params-decoder-layers-attention-key-kernel\": [\n",
       "            f\"model.layers.{i}.self_attn.k_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-attention-value-kernel\": [\n",
       "            f\"model.layers.{i}.self_attn.v_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-attention-query-kernel\": [\n",
       "            f\"model.layers.{i}.self_attn.q_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-attention-out-kernel\": [\n",
       "            f\"model.layers.{i}.self_attn.o_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-mlp-wi_0-kernel\": [\n",
       "            f\"model.layers.{i}.mlp.gate_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-mlp-wi_1-kernel\": [\n",
       "            f\"model.layers.{i}.mlp.up_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-mlp-wo-kernel\": [\n",
       "            f\"model.layers.{i}.mlp.down_proj.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-rms_norm-scale\": [\n",
       "            f\"model.layers.{i}.input_layernorm.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "        \"params-decoder-layers-ffn_rms_norm-scale\": [\n",
       "            f\"model.layers.{i}.post_attention_layernorm.weight\" for i in range(nlayers)\n",
       "        ],\n",
       "    }\n",
       "  else:\n",
       "    for layer_idx in range(nlayers):\n",
       "      layer_mapping = {\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": f\"model.layers.{layer_idx}.self_attn.k_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": f\"model.layers.{layer_idx}.self_attn.v_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": f\"model.layers.{layer_idx}.self_attn.q_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": f\"model.layers.{layer_idx}.self_attn.o_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": f\"model.layers.{layer_idx}.mlp.gate_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": f\"model.layers.{layer_idx}.mlp.up_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": f\"model.layers.{layer_idx}.mlp.down_proj.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": f\"model.layers.{layer_idx}.input_layernorm.weight\",\n",
       "          f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": f\"model.layers.{layer_idx}.post_attention_layernorm.weight\",\n",
       "      }\n",
       "      mapping = {**mapping, **layer_mapping}\n",
       "  return mapping\n",
       "\n",
       "\n",
       "def Gemma3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, scan_layers=False, saving_to_hf=False):\n",
       "  \"\"\"Creates parameter transformation functions for converting between MaxText and\n",
       "  HuggingFace formats.\n",
       "\n",
       "  This function generates a mapping of transformation functions that handle the necessary\n",
       "  conversions between MaxText and HuggingFace parameter formats, including operations like\n",
       "  padding, reshaping, and scaling.\n",
       "\n",
       "  Args:\n",
       "      config (dict): Model configuration dictionary that must contain:\n",
       "          - num_hidden_layers (int): Number of layers in the model\n",
       "          - head_dim (int): Dimension of attention heads\n",
       "          - hidden_size (int): Model's hidden dimension size\n",
       "\n",
       "      scan_layers (bool, optional): Controls the output format for layer parameters:\n",
       "          - True: Returns transformation functions for batched layer parameters\n",
       "          - False: Returns transformation functions for individual layer parameters\n",
       "          Defaults to False.\n",
       "\n",
       "      saving_to_hf (bool, optional): Determines the direction of transformation:\n",
       "          - True: MaxText → HuggingFace conversion\n",
       "          - False: HuggingFace → MaxText conversion\n",
       "          Defaults to False.\n",
       "\n",
       "  Returns:\n",
       "      dict: Parameter transformation mapping where:\n",
       "          - Keys: MaxText parameter names (str)\n",
       "          - Values: Either:\n",
       "              - callable: Single transformation function\n",
       "              - list[callable]: List of transformation functions to be applied in sequence\n",
       "\n",
       "  Transformation Details:\n",
       "      The function handles several types of parameter transformations:\n",
       "      1. Embedding layer padding:\n",
       "          - HF shape: [vocab_size, d_model]\n",
       "          - MaxText shape: [padded_vocab_size, d_model] (padded for performance)\n",
       "      2. Layer normalization scaling:\n",
       "          - Adds/subtracts 1.0 depending on direction\n",
       "      3. Attention query scaling:\n",
       "          - Scales by sqrt(head_dim) or its inverse\n",
       "\n",
       "      4. Kernel reshaping:\n",
       "          - Handles dimension transposition and reshaping between formats\n",
       "  \"\"\"\n",
       "  nlayers = config[\"num_hidden_layers\"]\n",
       "\n",
       "  def pad_hf_embedding_layer(input_tensor, target_shape):\n",
       "    \"\"\"Pads the HF embedding layer to match the MaxText embedding layer's shape.\n",
       "\n",
       "    Note:\n",
       "        HF embedding weights shape =  [vocab_size,d_model]\n",
       "        MaxText embedding weights shape = [padded_vocab_size,d_model]\n",
       "        MaxText pad Gemma3 embedding to padded_vocab_size for better performance.\n",
       "    \"\"\"\n",
       "    # TODO(wenxindongwork), Perhaps, this dtype should be the activation dtype\n",
       "    normalizer = np.dtype(\"float32\").type(config[\"hidden_size\"] ** 0.5)\n",
       "\n",
       "    def to_hf():\n",
       "      target_tensor = input_tensor[: target_shape[0], : target_shape[1]]\n",
       "      # target_tensor = target_tensor / normalizer  # no scale factor for embedding\n",
       "      target_tensor = target_tensor.astype(input_tensor.dtype)\n",
       "      return target_tensor\n",
       "\n",
       "    def from_hf():\n",
       "      target_tensor = np.zeros(target_shape, dtype=input_tensor.dtype)\n",
       "      target_tensor[: input_tensor.shape[0], : input_tensor.shape[1]] = input_tensor\n",
       "      # target_tensor = target_tensor * normalizer # no scale factor for embedding\n",
       "      target_tensor = target_tensor.astype(input_tensor.dtype)\n",
       "      return target_tensor\n",
       "\n",
       "    if saving_to_hf:\n",
       "      return to_hf()\n",
       "    else:\n",
       "      return from_hf()\n",
       "\n",
       "  def reshape_kernel(input_tensor, target_shape):\n",
       "    def to_hf():\n",
       "      flipped_target_shape = np.flip(np.array(target_shape))\n",
       "      return input_tensor.reshape(flipped_target_shape).T\n",
       "\n",
       "    def from_hf():\n",
       "      return input_tensor.T.reshape(target_shape)\n",
       "\n",
       "    if saving_to_hf:\n",
       "      return to_hf()\n",
       "    else:\n",
       "      return from_hf()\n",
       "\n",
       "  def scale_rmsnorm_layer(input_tensor, target_shape):\n",
       "    def to_hf():\n",
       "      return (input_tensor - 1.0).reshape(target_shape)\n",
       "\n",
       "    def from_hf():\n",
       "      return (input_tensor + 1.0).reshape(target_shape)\n",
       "\n",
       "    if saving_to_hf:\n",
       "      return to_hf()\n",
       "    else:\n",
       "      return from_hf()\n",
       "\n",
       "  def scale_query_layer(input_tensor, target_shape):\n",
       "    def to_hf():\n",
       "      depth_scale = np.dtype(\"float32\").type(np.sqrt(config[\"head_dim\"]))\n",
       "      return (input_tensor * depth_scale).astype(input_tensor.dtype)\n",
       "\n",
       "    def from_hf():\n",
       "      depth_scale = np.dtype(\"float32\").type(1 / np.sqrt(config[\"head_dim\"]))\n",
       "      return (input_tensor * depth_scale).astype(input_tensor.dtype)\n",
       "\n",
       "    if saving_to_hf:\n",
       "      return to_hf()\n",
       "    else:\n",
       "      return from_hf()\n",
       "\n",
       "  mapping = {\n",
       "      \"params-token_embedder-embedding\": pad_hf_embedding_layer,\n",
       "      \"params-decoder-decoder_norm-scale\": scale_rmsnorm_layer,\n",
       "  }\n",
       "  if scan_layers:\n",
       "    mapping = {\n",
       "        **mapping,\n",
       "        \"params-decoder-layers-attention-query-kernel\": [\n",
       "            reshape_kernel,\n",
       "            scale_query_layer,\n",
       "        ],\n",
       "        \"params-decoder-layers-attention-key-kernel\": reshape_kernel,\n",
       "        \"params-decoder-layers-attention-value-kernel\": reshape_kernel,\n",
       "        \"params-decoder-layers-mlp-wo-kernel\": reshape_kernel,\n",
       "        \"params-decoder-layers-mlp-wi_1-kernel\": reshape_kernel,\n",
       "        \"params-decoder-layers-mlp-wi_0-kernel\": reshape_kernel,\n",
       "        \"params-decoder-layers-attention-out-kernel\": reshape_kernel,\n",
       "        \"params-decoder-layers-rms_norm-scale\": scale_rmsnorm_layer,\n",
       "        \"params-decoder-layers-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n",
       "    }\n",
       "  else:\n",
       "    for layer_idx in range(nlayers):\n",
       "      mapping = {\n",
       "          **mapping,\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-query-kernel\": [\n",
       "              reshape_kernel,\n",
       "              scale_query_layer,\n",
       "          ],\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-key-kernel\": reshape_kernel,\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-value-kernel\": reshape_kernel,\n",
       "          f\"params-decoder-layers_{layer_idx}-mlp-wo-kernel\": reshape_kernel,\n",
       "          f\"params-decoder-layers_{layer_idx}-mlp-wi_1-kernel\": reshape_kernel,\n",
       "          f\"params-decoder-layers_{layer_idx}-mlp-wi_0-kernel\": reshape_kernel,\n",
       "          f\"params-decoder-layers_{layer_idx}-attention-out-kernel\": reshape_kernel,\n",
       "          f\"params-decoder-layers_{layer_idx}-rms_norm-scale\": scale_rmsnorm_layer,\n",
       "          f\"params-decoder-layers_{layer_idx}-ffn_rms_norm-scale\": scale_rmsnorm_layer,\n",
       "      }\n",
       "  return mapping\n",
       "\n",
       "\n",
       "def Gemma3_HF_WEIGHTS_TO_SHAPE_MAPPING(config):\n",
       "  \"\"\"Returns mapping between HuggingFace weights path and weights shape.\n",
       "\n",
       "  Args:\n",
       "      config (dict): Model configuration dictionary, defined in `model_configs.py`\n",
       "\n",
       "  Returns:\n",
       "      dict: A mapping where:\n",
       "          - Keys are HuggingFace model parameter paths\n",
       "          - Values are parameter shape as a List\n",
       "  \"\"\"\n",
       "\n",
       "  mapping = {\n",
       "      \"model.embed_tokens.weight\": [config[\"vocab_size\"], config[\"hidden_size\"]],\n",
       "      \"model.norm.weight\": [config[\"hidden_size\"]],\n",
       "  }\n",
       "  for layer_idx in range(config[\"num_hidden_layers\"]):\n",
       "    layer_mapping = {\n",
       "        f\"model.layers.{layer_idx}.input_layernorm.weight\": [config[\"hidden_size\"]],\n",
       "        f\"model.layers.{layer_idx}.post_attention_layernorm.weight\": [config[\"hidden_size\"]],\n",
       "        f\"model.layers.{layer_idx}.self_attn.q_proj.weight\": [\n",
       "            config[\"num_attention_heads\"] * config[\"head_dim\"],\n",
       "            config[\"hidden_size\"],\n",
       "        ],\n",
       "        f\"model.layers.{layer_idx}.self_attn.k_proj.weight\": [\n",
       "            config[\"num_key_value_heads\"] * config[\"head_dim\"],\n",
       "            config[\"hidden_size\"],\n",
       "        ],\n",
       "        f\"model.layers.{layer_idx}.self_attn.v_proj.weight\": [\n",
       "            config[\"num_key_value_heads\"] * config[\"head_dim\"],\n",
       "            config[\"hidden_size\"],\n",
       "        ],\n",
       "        f\"model.layers.{layer_idx}.self_attn.o_proj.weight\": [\n",
       "            config[\"hidden_size\"],\n",
       "            config[\"num_attention_heads\"] * config[\"head_dim\"],\n",
       "        ],\n",
       "        f\"model.layers.{layer_idx}.mlp.gate_proj.weight\": [\n",
       "            config[\"intermediate_size\"],\n",
       "            config[\"hidden_size\"],\n",
       "        ],\n",
       "        f\"model.layers.{layer_idx}.mlp.up_proj.weight\": [\n",
       "            config[\"intermediate_size\"],\n",
       "            config[\"hidden_size\"],\n",
       "        ],\n",
       "        f\"model.layers.{layer_idx}.mlp.down_proj.weight\": [\n",
       "            config[\"hidden_size\"],\n",
       "            config[\"intermediate_size\"],\n",
       "        ],\n",
       "    }\n",
       "    mapping = {**mapping, **layer_mapping}\n",
       "  return mapping\n",
       "\n",
       "```"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Instruction sent to the model; `target_model` is substituted into both the\n",
    "# description and the names of the three functions being requested.\n",
    "prompt = f\"\"\"\n",
    "  You are a code assist to help me find the checkpoint conversion from maxtext to huggingface. \n",
    "  The checkpoint does not fuse QKV vectors. \n",
    "  The transformer configs should be completely aligned with given model config for {target_model}\n",
    "  You need to generate the following code functions of {target_model} Model:\n",
    "    {target_model}_MAXTEXT_TO_HF_PARAM_MAPPING(); \n",
    "    {target_model}_MAXTEXT_TO_HF_PARAM_HOOK_FN();\n",
    "    {target_model}_HF_WEIGHTS_TO_SHAPE_MAPPING();\n",
    "\"\"\"\n",
    "\n",
    "# Send the prompt plus both uploaded context files to the model in one request.\n",
    "response = client.models.generate_content(\n",
    "    model=MODEL_ID,\n",
    "    contents=[prompt, param_file, shape_file],\n",
    ")\n",
    "\n",
    "# Last expression of the cell: render the model's text reply as Markdown.\n",
    "Markdown(response.text)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "agent_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
